scikit-learn 1.0でstableになったHistGradientBoostingClassifierを使ってみる

scikit-learn 1.0でstableになったHistGradientBoostingClassifierを使ってみる

Clock Icon2021.09.27

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

データアナリティクス事業本部の鈴木です。

はじめに

先日、scikit-learn 1.0がリリースされました。

Release Highlights for scikit-learn 1.0

リリースハイライトとリリースノートに修正点が記載されていますが、その中でHistogram-based Gradient Boosting Models are now stableのセクションは、大きな注目点の一つでしょう。

Histogram-based Gradient Boosting Models are now stable

上記セクションによると、以下の2つのヒストグラムベースの勾配ブースティングモデルがstableとなりました。

これらのモデルは、以前のバージョンではexperimentalという扱いではあったものの、性能が良く、欠損値のサポートや、scikit-learnの機械学習パイプライン互換のAPIを持っており、とても使い勝手が良かったため、注目されていた方が多いのではないでしょうか。

scikit-learn 1.0になったことで使い心地に大きな違いはなさそうですが、

例えばscikit-learn 0.24では

from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingClassifier

のようにsklearn.experimentalモジュールからもインポートする必要がありましたが、

scikit-learn 1.0では、

from sklearn.ensemble import HistGradientBoostingClassifier

のようにシンプルにインポートできるようになりました。

今回は、scikit-learn 1.0のHistGradientBoostingClassifierを使って、分類問題を解いてみることで、使い方をおさらいしたいと思います。

準備

検証した環境

  • コンテナ:jupyter/datascience-notebook
  • scikit-learn:1.0

私の使っているイメージだと、scikit-learn 0.24がインストールされているので、pipでscikit-learn 1.0をインストールします。

!pip install --upgrade scikit-learn
## ...
## Successfully installed scikit-learn-1.0

import sklearn
print(sklearn.__version__)
## 1.0

データの準備

今回はseaborn-dataのpenguinsを利用します。

penguinsは、以下のpalmerpenguinsをもとにしたデータです。

allisonhorst/palmerpenguins: A great intro dataset for data exploration & visualization (alternative to iris).

以下のようにしてデータを作成しておきます。

import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split

# データのロード
df = sns.load_dataset('penguins')

# カラムの型と欠損値の有無を確認
df.info()
## <class 'pandas.core.frame.DataFrame'>
## RangeIndex: 344 entries, 0 to 343
## Data columns (total 7 columns):
##  #   Column             Non-Null Count  Dtype  
## ---  ------             --------------  -----  
##  0   species            344 non-null    object 
##  1   island             344 non-null    object 
##  2   bill_length_mm     342 non-null    float64
##  3   bill_depth_mm      342 non-null    float64
##  4   flipper_length_mm  342 non-null    float64
##  5   body_mass_g        342 non-null    float64
##  6   sex                333 non-null    object 
## dtypes: float64(4), object(3)
## memory usage: 18.9+ KB

df.isnull().sum()
## species               0
## island                0
## bill_length_mm        2
## bill_depth_mm         2
## flipper_length_mm     2
## body_mass_g           2
## sex                  11
## dtype: int64

# 特徴と推定対象に分離
X = df.drop("species", axis=1)
y = df["species"]

# 訓練データとテストデータに分離
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

今回は、ペンギンの種類(speciesカラムの値)を推定したいとします。

特徴には量的変数とカテゴリー変数があり、いくつかのカラムには欠損値があることが分かります。

使ってみる

HistGradientBoostingClassifierを使って、ペンギンの種類を推定してみましょう。

機械学習パイプライン内では以下の前処理を行います。

  • カテゴリ変数はOne-hot表現にする。

また、learning_rateを指定してみます。パラメータは基本的にキーワード引数で渡す必要があります。

from sklearn.compose import ColumnTransformer
from sklearn.compose import make_column_selector
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

preprocessor = ColumnTransformer(
    transformers=[
        ('num', 'passthrough', make_column_selector(dtype_include=np.number)),
        ('cat1', OneHotEncoder(handle_unknown='ignore'), make_column_selector(dtype_include=object))])

pipe = Pipeline([("preprocessor", preprocessor),  
                 ("Classifier", HistGradientBoostingClassifier(learning_rate=0.15))])

作成したパイプラインを確認します。

from sklearn import set_config
set_config(display='diagram')   
pipe

作成したパイプラインのダイアグラム

最後に参考までにscoreを実行し、推論ができることを確認しておきます。

# 学習
pipe.fit(X_train, y_train)

# 推論結果のaccuracyを計算する。
pipe.score(X_test, y_test)
## 0.9767441860465116

最後に

scikit-learn 1.0のインストール方法と、stableとなったHistGradientBoostingClassifierの使い方をおさらいしました。

強力な勾配ブースティングモデルでありながら、pipelineモジュールとの親和性が高いので、scikit-learnで構築した機械学習パイプラインと合わせて、どんどん活用していきたいですね。

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.